Variational Gradient Matching for Dynamical Systems: Dynamic Causal Modeling
Authors:
Nico Stephan Gorbach and Stefan Bauer, email: nico.gorbach@gmail.com
Contents:
Instructional code for the NIPS (2017) paper Scalable Variational Inference for Dynamical Systems by Nico S. Gorbach, Stefan Bauer and Joachim M. Buhmann. Please cite our paper if you use our program in a further publication. The derivations in this document are also given in the doctoral thesis https://www.research-collection.ethz.ch/handle/20.500.11850/261734 as well as in parts of Wenk et al. (2018).
Example dynamical system used in this code: Dynamic Causal Modeling (visual attention system) with three hidden neuronal and 12 hidden hemodynamic states. The system is affected by given external inputs, and the states are only indirectly observed through the BOLD signal change equation.
User Input: Simulation Settings
- Simulation ODEs
Input the ODE "type" used to generate the data as a string. Options: 'nonlinear_forward_modulation_by_attention', 'forward_modulation_and_driven_by_attention', 'forward_modulation_by_attention', 'backward_modulation_by_attention', 'backward_modulation_and_driven_by_attention', 'absent_modulation', 'absent_attention_input', 'absent_photic_input', 'driven_by_attention', 'photic_input'.
simulation.odes = 'forward_modulation_and_driven_by_attention';
- Observed States
Input a cell vector containing the symbols (characters) in the '_ODEs.txt' file. E.g., to observe deoxyhemoglobin content, blood volume and blood flow, set simulation.observed_states = {'q_1','q_3','q_2','v_1','v_3','v_2','f_1','f_3','f_2'}. Leave the cell vector empty to observe the states only indirectly through the BOLD signal.
simulation.observed_states = {};
- Final time for simulation
Input a positive real number:
simulation.final_time = 359*3.22;
- Observation noise
Input a function handle:
simulation.state_obs_variance = @(x)(repmat(bsxfun(@rdivide,var(x),5),size(x,1),1));
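To make the behaviour of this function handle concrete, the following minimal sketch evaluates it on a made-up state matrix (rows are time points, columns are states). The toy values in x are assumptions for illustration only.

```matlab
% Toy state matrix: 4 time points (rows) x 2 states (columns); values are made up.
x = [0 10; 1 20; 2 30; 3 40];

% Same function handle as in the settings above.
state_obs_variance = @(x)(repmat(bsxfun(@rdivide,var(x),5),size(x,1),1));

% One variance per state (column variance divided by 5), repeated for every time point.
obs_variance = state_obs_variance(x);
disp(size(obs_variance))   % same size as x
disp(obs_variance(1,:))    % var(x)/5 for each state
```

The observation noise variance is therefore constant over time and scales with the empirical variance of each state trajectory.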
- Time interval between observations
Input a positive real number:
simulation.interval_between_observations = 0.1;
User Input: Estimation Settings
- Candidate ODEs
Input the ODE "type" used for estimation as a string. Options: 'nonlinear_forward_modulation_by_attention', 'forward_modulation_and_driven_by_attention', 'forward_modulation_by_attention', 'backward_modulation_by_attention', 'backward_modulation_and_driven_by_attention', 'absent_modulation', 'absent_attention_input', 'absent_photic_input', 'driven_by_attention', 'photic_input'.
candidate_odes = 'forward_modulation_and_driven_by_attention';
- Kernel parameters
Input a row vector of positive real numbers of size 1 x 2:
kernel.param = [10,0.2];
- Error variance on state derivatives (i.e. γ)
Input a row vector of positive real numbers of size 1 x number of ODEs:
state.derivative_variance = 6.*ones(11-3,1);
- Estimation times
Input a row vector of non-negative real numbers in ascending order:
time.est = 0:3.22:359*3.22;
Preliminary operations
close all; clc; addpath('VGM_functions')
Preprocessing for candidate ODEs
[symbols,ode,plot_settings,state,simulation,odes_path,coupling_idx,opt_settings] = ...
preprocessing_dynamic_causal_modeling (simulation,candidate_odes,state);
Simulate Trajectories
- Preprocessing for true ODEs
[symbols_true,ode_true] = preprocessing_dynamic_causal_modeling (simulation,simulation.odes,state);
Sample ODE parameters that lead to non-diverging trajectories:
non_diverging_trajectories = false; i = 0;
while ~non_diverging_trajectories
- Sample ODE parameters
Non-self-inhibitory neuronal couplings (sampled uniformly in the interval [-0.8, 0.8]):
simulation.ode_param = -0.8 + (0.8-(-0.8)) * rand(1,length(symbols_true.param));
% simulation.ode_param = [0.46,0.13,0.39,0.26,0.5,0.26,0.1,1.25,-1,-1,-1]; % published ODE parameters (slightly modified from Stephan et al., 2008)
Self-inhibitory neuronal couplings are set to -1:
simulation.ode_param(end-2:end) = -1;
- Numerical integration
try
simulation_old = simulation;
[simulation,obs_to_state_relation,fig_handle,plot_handle] = simulate_state_dynamics_dcm(...
simulation,symbols_true,ode_true,time,plot_settings,state.ext_input,'plot');
non_diverging_trajectories = 1;
end
end
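The sampling loop above is a form of rejection sampling: parameters are redrawn until the numerical integration succeeds. The following self-contained sketch shows the same pattern on a toy problem. The linear ODE, the divergence threshold of 1e3 and the error identifier are assumptions chosen for illustration; they stand in for simulate_state_dynamics_dcm, which is assumed to raise an error for diverging trajectories.

```matlab
% Hypothetical stand-in for the DCM simulator: the linear ODE dx/dt = a*x,
% integrated with ode45. A trajectory counts as diverging if it grows too large.
rng(0);                                  % reproducible draws
non_diverging = false;
num_attempts = 0;
while ~non_diverging
    num_attempts = num_attempts + 1;
    a = -0.8 + (0.8-(-0.8)) * rand;      % uniform in [-0.8, 0.8], as above
    try
        [~,x] = ode45(@(t,x) a*x, [0 20], 1);
        if any(~isfinite(x)) || max(abs(x)) > 1e3
            error('toy:diverged','trajectory diverged');
        end
        non_diverging = true;            % accept this parameter draw
    catch
        % rejected draw: sample again in the next pass of the loop
    end
end
disp(num_attempts)
```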
Mass Action Dynamical Systems
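A deterministic dynamical system is represented by a set of K ordinary differential equations (ODEs) with model parameters $\theta$ that describe the evolution of K states $\mathbf{x}(t) = [x_1(t),\ldots,x_K(t)]^T$ such that:
$$\dot{\mathbf{x}}(t) = \frac{d\mathbf{x}(t)}{dt} = \mathbf{f}(\mathbf{x}(t),\theta). \qquad (1)$$
A sequence of observations, $\mathbf{y}(t)$, is usually contaminated by measurement error, which we assume to be normally distributed with zero mean and variance $\sigma_k^2$ for each of the K states, i.e. $\mathbf{E} \sim \mathcal{N}(\mathbf{0},\mathbf{D})$, with $\mathbf{D}_{ik} = \sigma_k^2\,\delta_{ik}$. For N distinct time points the overall system may therefore be summarized as
$$\mathbf{Y} = \mathbf{X} + \mathbf{E},$$
where $\mathbf{X} = [\mathbf{x}(t_1),\ldots,\mathbf{x}(t_N)] = [\mathbf{x}_1,\ldots,\mathbf{x}_K]^T$ and $\mathbf{Y} = [\mathbf{y}(t_1),\ldots,\mathbf{y}(t_N)] = [\mathbf{y}_1,\ldots,\mathbf{y}_K]^T$; here $\mathbf{x}_k = [x_k(t_1),\ldots,x_k(t_N)]^T$ is the k'th state sequence and $\mathbf{y}_k = [y_k(t_1),\ldots,y_k(t_N)]^T$ are the observations. Given the observations $\mathbf{Y}$ and the description of the dynamical system (1), the aim is to estimate both state variables $\mathbf{X}$ and parameters $\theta$.
We consider only dynamical systems that are locally linear with respect to ODE parameters $\theta$ and individual states $\mathbf{x}_u$. Such ODEs include mass-action kinetics and are given by:
$$f_k(\mathbf{x}(t),\theta) = \sum_{i} \theta_{ki} \prod_{j \in \mathcal{M}_{ki}} x_j, \qquad (2)$$
with $\mathcal{M}_{ki} \subseteq \{1,\ldots,K\}$ describing the state variables in each factor of the equation (i.e. the functions are linear in parameters and contain arbitrarily large products of monomials of the states).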
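As a small illustration of the mass-action form, the sketch below uses the Lotka-Volterra system rather than the DCM model (an assumption made to keep the example short) and checks numerically that the ODEs are linear in the parameters. All numerical values are made up.

```matlab
% Lotka-Volterra as a mass-action system (illustrative; not the DCM model):
%   dx1/dt =  theta1*x1 - theta2*x1*x2
%   dx2/dt = -theta3*x2 + theta4*x1*x2
f = @(x,theta) [ theta(1)*x(1) - theta(2)*x(1)*x(2); ...
                -theta(3)*x(2) + theta(4)*x(1)*x(2)];
x = [5; 3];                         % an arbitrary state
theta_a = [2; 1; 4; 1];             % two arbitrary parameter vectors
theta_b = [1; 0.5; 2; 0.3];

% Linearity in the parameters: f(x, 2*a + 3*b) equals 2*f(x,a) + 3*f(x,b).
lhs = f(x, 2*theta_a + 3*theta_b);
rhs = 2*f(x,theta_a) + 3*f(x,theta_b);
disp(max(abs(lhs - rhs)))           % zero up to rounding error
```

Each ODE is also linear in an individual state when the remaining states are held fixed, which is the property exploited later for the state proxies.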
start timer
tic;
Prior on States and State Derivatives
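Gradient matching with Gaussian processes assumes a joint Gaussian process prior on states and their derivatives:
$$\begin{bmatrix}\mathbf{X}\\ \dot{\mathbf{X}}\end{bmatrix} \sim \mathcal{N}\left(\mathbf{0}, \begin{bmatrix}\mathbf{C}_{\phi} & \mathbf{C}'_{\phi}\\ {}'\mathbf{C}_{\phi} & \mathbf{C}''_{\phi}\end{bmatrix}\right), \qquad (3)$$
with $\mathbf{C}_{\phi}$ the kernel matrix of the states, $\mathbf{C}'_{\phi}$ and ${}'\mathbf{C}_{\phi}$ the cross-covariances between states and state derivatives, $\mathbf{C}''_{\phi}$ the auto-covariance of the state derivatives, and $\phi$ the kernel hyperparameters.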
Matching Gradients
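Given the joint distribution over states and their derivatives (3) as well as the ODEs (2), we therefore have two expressions for the state derivatives:
$$\dot{\mathbf{X}} = \mathbf{f}(\mathbf{X},\theta) + \epsilon_1, \quad \epsilon_1 \sim \mathcal{N}(\mathbf{0},\gamma\mathbf{I}),$$
$$\dot{\mathbf{X}} = {}'\mathbf{C}_{\phi}\mathbf{C}_{\phi}^{-1}\mathbf{X} + \epsilon_2, \quad \epsilon_2 \sim \mathcal{N}(\mathbf{0},\mathbf{A}),$$
where $\mathbf{A} := \mathbf{C}''_{\phi} - {}'\mathbf{C}_{\phi}\mathbf{C}_{\phi}^{-1}\mathbf{C}'_{\phi}$ and $\gamma$ is the error variance in the ODEs. Note that, in a deterministic system, the output of the ODEs $\mathbf{f}(\mathbf{X},\theta)$ should equal the state derivatives $\dot{\mathbf{X}}$. However, in the first equation above we relax this constraint by adding stochasticity to the state derivatives $\dot{\mathbf{X}}$ in order to compensate for a potential model mismatch. The second equation above is obtained by deriving the conditional distribution for $\dot{\mathbf{X}}$ from the joint distribution in equation (3). Equating the two expressions in the equations above, we can eliminate the unknown state derivatives $\dot{\mathbf{X}}$:
$$\mathbf{f}(\mathbf{X},\theta) = {}'\mathbf{C}_{\phi}\mathbf{C}_{\phi}^{-1}\mathbf{X} + \epsilon_0, \qquad (4)$$
with $\epsilon_0 := \epsilon_2 - \epsilon_1 \sim \mathcal{N}(\mathbf{0},\mathbf{A}+\gamma\mathbf{I})$.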
[dC_times_invC,inv_C,A_plus_gamma_inv] = kernel_function(kernel,state,time.est);
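The following sketch shows one way the three outputs above could be computed. It assumes a squared-exponential kernel with variance sigma2 and length scale len, and a small diagonal jitter for numerical stability; the kernel actually implemented in kernel_function, and the meaning of the two entries of kernel.param, are not stated in this document, so these choices are assumptions.

```matlab
t = (0:0.5:10)';                           % estimation times (column vector)
sigma2 = 1; len = 1;                       % assumed kernel variance and length scale
gamma = 6;                                 % ODE error variance, as in the settings above
dt = bsxfun(@minus,t,t');                  % matrix of differences t_i - t_j

C   = sigma2 * exp(-dt.^2 ./ (2*len^2));   % cov(x(t_i), x(t_j))
dC  = -(dt ./ len^2) .* C;                 % 'C : cov(dx(t_i), x(t_j))
Cd  =  (dt ./ len^2) .* C;                 % C' : cov(x(t_i), dx(t_j))
ddC = (1/len^2 - dt.^2 ./ len^4) .* C;     % C'': cov(dx(t_i), dx(t_j))

C_jit = C + 1e-6 * eye(numel(t));          % jitter for numerical stability (assumption)
inv_C = inv(C_jit);
dC_times_invC = dC / C_jit;                % 'C * C^{-1}
A = ddC - dC_times_invC * Cd;              % conditional covariance of the derivatives
A_plus_gamma_inv = inv(A + gamma * eye(numel(t)));

% Sanity check: the GP derivative of sin(t) should be close to cos(t) away from the edges.
dx_est = dC_times_invC * sin(t);
disp([dx_est(11), cos(t(11))])
```

The matrix dC_times_invC maps a state sequence to its expected derivative under the GP prior, which is the quantity that is matched against the ODE output in equation (4).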
Rewrite ODEs as Linear Combination in Parameters
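Since, according to the mass action dynamics (equation 2), the ODEs are linear in the parameters $\theta$, we can rewrite the ODEs in equation (2) as a linear combination in the parameters:
$$\mathbf{B}\theta + \mathbf{b} = \mathbf{f}(\mathbf{X},\theta), \qquad (5)$$
where matrices $\mathbf{B}$ and $\mathbf{b}$ are defined such that the ODEs $\mathbf{f}(\mathbf{X},\theta)$ are expressed as a linear combination in $\theta$.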
[ode_param.lin_comb.B,ode_param.lin_comb.b] = rewrite_odes_as_linear_combination_in_parameters(ode,symbols);
Posterior over ODE Parameters
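Inserting (5) into (4) and solving for $\theta$ yields:
$$\theta = \mathbf{B}^{+}\left({}'\mathbf{C}_{\phi}\mathbf{C}_{\phi}^{-1}\mathbf{X} - \mathbf{b} + \epsilon_0\right),$$
where $\mathbf{B}^{+}$ denotes the pseudo-inverse of $\mathbf{B}$. Since $\mathbf{C}_{\phi}$ is block diagonal we can rewrite the expression above as:
$$\theta = (\mathbf{B}^T\mathbf{B})^{-1}\sum_k \mathbf{B}_k^T\left({}'\mathbf{C}_{\phi_k}\mathbf{C}_{\phi_k}^{-1}\mathbf{x}_k - \mathbf{b}_k + \epsilon_0^{(k)}\right),$$
where we substitute the Moore-Penrose inverse for the pseudo-inverse (i.e. $\mathbf{B}^{+} := (\mathbf{B}^T\mathbf{B})^{-1}\mathbf{B}^T$). We can therefore derive the posterior distribution over ODE parameters:
$$p(\theta \mid \mathbf{X},\phi,\gamma) = \mathcal{N}\left(\theta \,\middle|\, (\mathbf{B}^T\mathbf{B})^{-1}\sum_k \mathbf{B}_k^T\left({}'\mathbf{C}_{\phi_k}\mathbf{C}_{\phi_k}^{-1}\mathbf{x}_k - \mathbf{b}_k\right),\; \mathbf{B}^{+}(\mathbf{A}+\gamma\mathbf{I})\mathbf{B}^{+T}\right). \qquad (6)$$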
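To illustrate the linear rewriting and the resulting posterior mean, the sketch below builds B and b for the Lotka-Volterra system (an assumption; it is not the DCM model) and recovers the parameters by least squares. To keep the example short, the GP-based derivative estimate is replaced by the exact derivatives evaluated along a simulated trajectory, which is a simplification and not what the script does.

```matlab
% Lotka-Volterra (illustrative, not the DCM model):
%   dx1/dt =  th1*x1 - th2*x1*x2,   dx2/dt = -th3*x2 + th4*x1*x2
theta_true = [2; 1; 4; 1];
f = @(t,x) [ theta_true(1)*x(1) - theta_true(2)*x(1)*x(2); ...
            -theta_true(3)*x(2) + theta_true(4)*x(1)*x(2)];
t = (0:0.05:2)';
[~,X] = ode45(f, t, [5; 3], odeset('RelTol',1e-10,'AbsTol',1e-10));
x1 = X(:,1); x2 = X(:,2); N = numel(t);

% Stack both ODEs: f(X,theta) = B*theta + b, with b = 0 for this system.
B = [x1, -x1.*x2, zeros(N,2); ...
     zeros(N,2), -x2, x1.*x2];
b = zeros(2*N,1);

% Stand-in for the GP derivative estimate: exact derivatives from the ODEs.
dX = [ theta_true(1)*x1 - theta_true(2)*x1.*x2; ...
      -theta_true(3)*x2 + theta_true(4)*x1.*x2];

% Posterior mean of theta: B^+ (dX - b), with B^+ = (B'*B)^{-1} * B'.
theta_mean = (B' * B) \ (B' * (dX - b));
disp(theta_mean')
```

Because the derivatives here are exact, the least-squares solution reproduces theta_true up to rounding error; with GP-based derivatives the estimate is only approximate.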
Rewrite Hemodynamic ODEs as Linear Combination in (monotonic functions of) Individual Hemodynamic States
- Deoxyhemoglobin content
Rewrite the BOLD signal change equation as a linear combination in a monotonic function of the deoxyhemoglobin content $q$:
[state.deoxyhemo.R,state.deoxyhemo.r] = rewrite_bold_signal_eqn_as_linear_combination_in_deoxyhemo(symbols);
- Blood volume
Rewrite the deoxyhemoglobin content ODE as a linear combination in a monotonic function of the blood volume $v$:
[state.vol.R,state.vol.r] = rewrite_deoxyhemo_ODE_as_linear_combination_in_vol(ode,symbols);
- Blood flow
Rewrite the blood volume ODE as a linear combination in a monotonic function of the blood flow $f$:
[state.flow.R,state.flow.r] = rewrite_vol_ODE_as_linear_combination_in_flow(ode,symbols);
- Vasosignalling
Rewrite the blood flow and vasosignalling ODEs as a linear combination in vasosignalling $s$:
[state.vaso.R,state.vaso.r] = rewrite_vaso_and_flow_odes_as_linear_combination_in_vaso(ode,symbols);
Rewrite Neuronal ODEs as Linear Combination in Individual Neuronal States
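We rewrite the ODE(s) $\mathbf{f}_k(\mathbf{X},\theta)$ as a linear combination in the individual state $\mathbf{x}_u$:
$$\mathbf{R}_{uk}\mathbf{x}_u + \mathbf{r}_{uk} = \mathbf{f}_k(\mathbf{X},\theta),$$
where matrices $\mathbf{R}_{uk}$ and $\mathbf{r}_{uk}$ are defined such that the ODE $\mathbf{f}_k(\mathbf{X},\theta)$ is expressed as a linear combination in the individual state $\mathbf{x}_u$.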
[state.neuronal.R,state.neuronal.r] = rewrite_odes_as_linear_combination_in_ind_neuronal_states(ode,symbols,coupling_idx.states);
Posterior over Individual States
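Given the linear combination of the ODEs w.r.t. an individual state, we define the matrices $\mathbf{B}_u$ and $\mathbf{b}_u$ such that the expression $\mathbf{f}(\mathbf{X},\theta) - {}'\mathbf{C}_{\phi}\mathbf{C}_{\phi}^{-1}\mathbf{X}$ is rewritten as a linear combination in an individual state $\mathbf{x}_u$:
$$\mathbf{B}_u\mathbf{x}_u + \mathbf{b}_u = \mathbf{f}(\mathbf{X},\theta) - {}'\mathbf{C}_{\phi}\mathbf{C}_{\phi}^{-1}\mathbf{X}. \qquad (7)$$
Inserting (7) into (4) and solving for $\mathbf{x}_u$ yields:
$$\mathbf{x}_u = \mathbf{B}_u^{+}\left(\epsilon_0 - \mathbf{b}_u\right),$$
where $\mathbf{B}_u^{+}$ denotes the pseudo-inverse of $\mathbf{B}_u$. Since $\mathbf{C}_{\phi}$ is block diagonal we can rewrite the expression above as:
$$\mathbf{x}_u = (\mathbf{B}_u^T\mathbf{B}_u)^{-1}\sum_k \mathbf{B}_{uk}^T\left(\epsilon_0^{(k)} - \mathbf{b}_{uk}\right),$$
where we substitute the Moore-Penrose inverse for the pseudo-inverse (i.e. $\mathbf{B}_u^{+} := (\mathbf{B}_u^T\mathbf{B}_u)^{-1}\mathbf{B}_u^T$). We can therefore derive the posterior distribution over an individual state $\mathbf{x}_u$:
$$p(\mathbf{x}_u \mid \mathbf{X}_{-u},\theta,\phi,\gamma) = \mathcal{N}\left(\mathbf{x}_u \,\middle|\, -(\mathbf{B}_u^T\mathbf{B}_u)^{-1}\sum_k \mathbf{B}_{uk}^T\mathbf{b}_{uk},\; \mathbf{B}_u^{+}(\mathbf{A}+\gamma\mathbf{I})\mathbf{B}_u^{+T}\right), \qquad (8)$$
with $\mathbf{X}_{-u}$ denoting the set of all states except state $\mathbf{x}_u$.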
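The construction of the matrices for an individual state can be checked numerically. The sketch below again uses the Lotka-Volterra system as an illustrative assumption and builds the blocks for the first state. The matrix D is a random placeholder for the GP derivative operator, since the identity being checked holds for any such matrix; all variable names are hypothetical.

```matlab
rng(1);
N = 5;
x1 = rand(N,1) + 1;  x2 = rand(N,1) + 1;     % made-up state sequences
theta = [2; 1; 4; 1];
D = randn(N);                                % placeholder for 'C*C^{-1}

% ODE outputs for both states (Lotka-Volterra)
f1 =  theta(1)*x1 - theta(2)*x1.*x2;
f2 = -theta(3)*x2 + theta(4)*x1.*x2;

% Blocks of B_u and b_u for u = 1 (state x1), one block per ODE k
B_11 = diag(theta(1) - theta(2)*x2) - D;     % k = 1: f1 - D*x1 is linear in x1
b_11 = zeros(N,1);
B_12 = diag(theta(4)*x2);                    % k = 2: only the x1*x2 term depends on x1
b_12 = -theta(3)*x2 - D*x2;
B_u = [B_11; B_12];  b_u = [b_11; b_12];

% Check: B_u*x_u + b_u equals the ODE output minus the GP derivative estimate
residual = B_u*x1 + b_u - [f1 - D*x1; f2 - D*x2];
disp(max(abs(residual)))                     % zero up to rounding error
```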
Mean-field Variational Inference
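To infer the parameters $\theta$, we want to find the maximum a posteriori estimate (MAP):
$$\theta^{*} := \arg\max_{\theta} \ln p(\theta \mid \mathbf{Y},\phi,\gamma,\sigma) = \arg\max_{\theta} \ln \int p(\theta \mid \mathbf{X},\phi,\gamma)\, p(\mathbf{X} \mid \mathbf{Y},\phi,\sigma)\, d\mathbf{X}. \qquad (9)$$
However, the integral above is intractable due to the strong couplings induced by the nonlinear ODEs $\mathbf{f}$, which appear in the term $p(\theta \mid \mathbf{X},\phi,\gamma)$.
We use mean-field variational inference to establish variational lower bounds that are analytically tractable by decoupling state variables from the ODE parameters as well as decoupling the state variables from each other. Note that, since the ODEs described by equation (2) are locally linear, both conditional distributions $p(\theta \mid \mathbf{X},\phi,\gamma)$ (equation (6)) and $p(\mathbf{x}_u \mid \mathbf{X}_{-u},\theta,\phi,\gamma)$ (equation (8)) are analytically tractable and Gaussian distributed, as mentioned previously. The decoupling is induced by designing a variational distribution $Q(\theta,\mathbf{X})$ which is restricted to the family of factorial distributions:
$$\mathcal{Q} := \left\{ Q : Q(\theta,\mathbf{X}) = q(\theta \mid \lambda) \prod_u q(\mathbf{x}_u \mid \psi_u) \right\}.$$
The particular forms of $q(\theta \mid \lambda)$ and $q(\mathbf{x}_u \mid \psi_u)$ are designed to be Gaussian, which places them in the same family as the true full conditional distributions. To find the optimal factorial distribution we minimize the Kullback-Leibler divergence between the variational and the true posterior distribution:
$$\hat{Q} := \arg\min_{Q \in \mathcal{Q}} \mathrm{KL}\left[ Q(\theta,\mathbf{X}) \,\|\, p(\theta,\mathbf{X} \mid \mathbf{Y},\phi,\gamma,\sigma) \right], \qquad (10)$$
where $\hat{Q}$ is the proxy distribution. The proxy distribution that minimizes the KL-divergence (10) depends on the true full conditionals and is given by:
$$\hat{q}(\theta \mid \lambda) \propto \exp\left( \mathbb{E}_{Q_{-\theta}} \ln p(\theta \mid \mathbf{X},\phi,\gamma) \right), \qquad \hat{q}(\mathbf{x}_u \mid \psi_u) \propto \exp\left( \mathbb{E}_{Q_{-u}} \ln p(\mathbf{x}_u \mid \mathbf{X}_{-u},\theta,\phi,\gamma) \right). \qquad (11)$$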
Denoising BOLD Observations
We denoise the BOLD observation by standard GP regression.
bold_response.denoised_obs = denoising_BOLD_observations(simulation.bold_response{:,{'n_1','n_3','n_2'}},inv_C,symbols,simulation);
Fitting observations of state trajectories
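We fit the observations of state trajectories by standard GP regression. The data-informed distribution $p(\mathbf{X} \mid \mathbf{Y},\phi,\sigma)$ in equation (9) can be determined analytically using Gaussian process regression with the GP prior $p(\mathbf{X} \mid \phi) = \prod_k \mathcal{N}(\mathbf{x}_k \mid \mathbf{0},\mathbf{C}_{\phi_k})$:
$$p(\mathbf{X} \mid \mathbf{Y},\phi,\sigma) = \prod_k \mathcal{N}\left(\mathbf{x}_k \mid \mu_k(\mathbf{y}_k),\Sigma_k\right),$$
where $\mu_k(\mathbf{y}_k) := \sigma_k^{-2}\left(\sigma_k^{-2}\mathbf{I} + \mathbf{C}_{\phi_k}^{-1}\right)^{-1}\mathbf{y}_k$ and $\Sigma_k^{-1} := \sigma_k^{-2}\mathbf{I} + \mathbf{C}_{\phi_k}^{-1}$.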
[mu,inv_sigma] = fitting_state_observations(inv_C,obs_to_state_relation,simulation,symbols);
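The GP regression step can be sketched for a single state as follows. The squared-exponential kernel, the jitter, the noise level and the sine-shaped trajectory are all assumptions for illustration; fitting_state_observations may differ in its details.

```matlab
rng(2);
t = (0:0.25:6)';  N = numel(t);
sigma2_obs = 0.05;                               % observation noise variance (made up)
y = sin(t) + sqrt(sigma2_obs) * randn(N,1);      % noisy observations of one state

dt = bsxfun(@minus,t,t');
C = exp(-dt.^2 ./ 2) + 1e-4 * eye(N);            % assumed squared-exponential kernel plus jitter
inv_C = inv(C);

inv_sigma = eye(N) / sigma2_obs + inv_C;         % Sigma^{-1} = sigma^{-2} I + C^{-1}
mu = inv_sigma \ (y / sigma2_obs);               % mu = sigma^{-2} Sigma y

% Algebraically equivalent form that avoids inverting C explicitly:
mu_alt = C * ((C + sigma2_obs * eye(N)) \ y);
disp(max(abs(mu - mu_alt)))                      % small, limited by rounding error
```

The second form is usually preferable numerically, because kernel matrices tend to be ill-conditioned and their explicit inverse amplifies rounding errors.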
Coordinate Ascent Variational Gradient Matching
We minimize the KL-divergence in equation (10) by coordinate descent (where each step is analytically tractable), iterating between determining the proxy for the distribution over ODE parameters $\hat{q}(\theta \mid \lambda)$ and the proxies for the distribution over individual states $\hat{q}(\mathbf{x}_u \mid \psi_u)$.
- Initialize the state estimation by the GP regression posterior
state.proxy.mean = array2table([time.est',mu],'VariableNames',['time',symbols.state_string]);
bold_response.obs_old = bold_response.denoised_obs;
ode_param.proxy.mean = zeros(length(symbols.param),1);
- Coordinate ascent
for i = 1:opt_settings.coord_ascent_numb_iter
Proxy for Hemodynamic States
Determine the proxies for the states, starting with deoxyhemoglobin content followed by blood volume, blood flow and finally vasosignalling. The information flow in the hemodynamic system is shown in its factor graph below:
[Figure: factor graph of the hemodynamic system]
The model inversion in the hemodynamic factor graph above occurs locally w.r.t. individual states. Given the expression for the BOLD signal change equation, we invert the BOLD signal change equation analytically to determine the deoxyhemoglobin content $q$ (1). The newly inferred deoxyhemoglobin content $q$ influences the expression for the factor associated with the change in deoxyhemoglobin content, which we subsequently invert analytically to infer the blood volume $v$ (2). Thereafter, we infer the blood flow $f$ (3) by inverting the factors associated with the change in blood volume as well as vasosignalling, followed by inferring vasosignalling $s$ (4) by inverting the factors associated with blood flow induction and vasosignalling. Finally, the neuronal dynamics (5) are learned, in part, by inverting the factor associated with vasosignalling. The typical trajectories of each of the states are shown (red) together with their iterative approximation (grey lines) obtained by graphical DCM.
- Proxy for deoxyhemoglobin content
Damping is required since we invert only the factor for the BOLD signal change equation w.r.t. a monotonic function of deoxyhemoglobin content $q$.
Undamped proxy:
state_proxy_undamped = proxy_for_deoxyhemoglobin_content(state.deoxyhemo,state.proxy.mean{:,symbols.state_string},...
bold_response.denoised_obs,symbols,A_plus_gamma_inv,opt_settings);
Damped proxy:
state.proxy.mean{:,{'q_1','q_3','q_2'}} = (1-opt_settings.damping) * state.proxy.mean{:,{'q_1','q_3','q_2'}} + ...
opt_settings.damping * state_proxy_undamped;
- Proxy for blood volume
Damping is required since we invert only a subset of ODEs w.r.t. a monotonic function of blood volume $v$.
Undamped proxy:
state_proxy_undamped = proxy_for_blood_volume(state.vol,dC_times_invC,state.proxy.mean{:,symbols.state_string},...
ode_param.proxy.mean,symbols,A_plus_gamma_inv,opt_settings);
Damped proxy:
state.proxy.mean{:,{'v_1','v_3','v_2'}} = (1-opt_settings.damping) * state.proxy.mean{:,{'v_1','v_3','v_2'}} + ...
opt_settings.damping * state_proxy_undamped;
- Proxy for blood flow
Damping is required since we invert only a subset of ODEs w.r.t. a monotonic function of blood flow $f$.
Undamped proxy:
state_proxy_undamped = proxy_for_blood_flow(state.flow,dC_times_invC,state.proxy.mean{:,symbols.state_string},...
ode_param.proxy.mean,symbols,A_plus_gamma_inv,opt_settings);
Damped proxy:
state.proxy.mean{:,{'f_1','f_3','f_2'}} = (1-opt_settings.damping) * state.proxy.mean{:,{'f_1','f_3','f_2'}} + ...
opt_settings.damping * state_proxy_undamped;
- Proxy for vasosignalling
No damping is required because we invert all ODEs w.r.t. vasosignalling $s$.
state.proxy.mean{:,{'s_1','s_3','s_2'}} = proxy_for_vasosignalling(state.vaso,dC_times_invC,...
state.proxy.mean{:,symbols.state_string},ode_param.proxy.mean,symbols,A_plus_gamma_inv,opt_settings);
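The damped updates used for deoxyhemoglobin content, blood volume and blood flow are convex combinations of the previous proxy and the newly computed undamped proxy. The toy sketch below shows the same update rule on a scalar fixed-point problem. The damping value and the cosine update are assumptions for illustration; the script reads the actual value from opt_settings.damping.

```matlab
damping = 0.1;                 % hypothetical value; the script reads opt_settings.damping
proxy = 0;                     % current estimate
for iter = 1:200
    undamped = cos(proxy);     % stand-in for an undamped proxy update
    % Damped update: move only a fraction of the way towards the undamped proxy.
    proxy = (1-damping) * proxy + damping * undamped;
end
disp(proxy)                    % approaches the fixed point of x = cos(x)
```

Smaller damping values slow the iteration down but make it less prone to oscillation when the undamped update is based on only part of the available information.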
Proxy for Neuronal States
Determine the proxies for the neuronal states. An example of the information flow in the neuronal part of the nonlinear forward modulation model (nonlin_fwd_mod) is shown in its factor graph below:
[Figure: factor graph of the neuronal model with nonlinear forward modulation]
In the neuronal factor graph above (for the nonlinear forward modulation), each individual state appears linearly in every factor of the neuronal model. We can therefore analytically invert every factor to determine the neuronal state. The typical trajectories of each of the states are shown (red) together with their iterative approximation (grey lines) obtained by variational gradient matching.
No damping is required because we invert all ODEs w.r.t. the neuronal populations $n$.
state.proxy.mean{:,{'n_1','n_3','n_2'}} = proxy_for_neuronal_populations(state.neuronal,...
state.proxy.mean{:,symbols.state_string},ode_param.proxy.mean',dC_times_invC,...
coupling_idx.states,symbols,A_plus_gamma_inv,opt_settings);
Keep initial value at zero:
state_idx = cellfun(@(x) ~strcmp(x(1),'u'),symbols.state_string);
state.proxy.mean{:,symbols.state_string(state_idx)} = bsxfun(@minus,state.proxy.mean{:,symbols.state_string(state_idx)},...
state.proxy.mean{1,symbols.state_string(state_idx)});
Proxy for ODE parameters
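Expanding the proxy distribution in equation (11) for $\theta$ yields:
$$\hat{q}(\theta \mid \lambda) \propto \exp\left( \mathbb{E}_{Q_{-\theta}} \ln \mathcal{N}\left(\theta \,\middle|\, \mathbf{B}^{+}\left({}'\mathbf{C}_{\phi}\mathbf{C}_{\phi}^{-1}\mathbf{X} - \mathbf{b}\right),\; \mathbf{B}^{+}(\mathbf{A}+\gamma\mathbf{I})\mathbf{B}^{+T}\right) \right),$$
where we substitute $p(\theta \mid \mathbf{X},\phi,\gamma)$ with its density given in equation (6).
No damping is required because we invert all ODEs w.r.t. the neuronal couplings $\theta$.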
if i>200 || i==opt_settings.coord_ascent_numb_iter
[ode_param.proxy.mean,ode_param.proxy.inv_cov] = proxy_for_ode_parameters(...
state.proxy.mean{:,symbols.state_string},dC_times_invC,ode_param.lin_comb,...
symbols,A_plus_gamma_inv,opt_settings);
end
Intercept due to Confounding Effects
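The BOLD response is given by:
$$\mathbf{y} = \lambda(\mathbf{q},\mathbf{v}) + \mathbf{X}_0\beta + \epsilon,$$
where $\mathbf{y}$ are the BOLD observations, $\lambda(\mathbf{q},\mathbf{v})$ is the BOLD signal change equation and the matrix $\mathbf{X}_0$ is given. The intercept $\mathbf{X}_0\hat{\beta}$ is determined by an ordinary least-squares estimator:
$$\mathbf{X}_0\hat{\beta} = \mathbf{X}_0\left(\mathbf{X}_0^T\mathbf{X}_0\right)^{-1}\mathbf{X}_0^T\left(\mathbf{y} - \lambda(\mathbf{q},\mathbf{v})\right).$$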
bold_signal_change = bold_signal_change_eqn(state.proxy.mean{:,{'v_1','v_3','v_2'}},state.proxy.mean{:,{'q_1','q_3','q_2'}});
intercept = simulation.X0 * (simulation.X0' * simulation.X0)^(-1) * simulation.X0' * (bold_response.obs_old-bold_signal_change);
bold_response.denoised_obs = bold_response.obs_old + intercept;
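The intercept computation is an orthogonal projection of the residual (observations minus BOLD signal change) onto the column space of the confound matrix. The sketch below reproduces this on synthetic data. The confound matrix with a constant and a linear drift column, the sine-shaped signal and the coefficients are assumptions for illustration and are not taken from simulation.X0.

```matlab
N = 50;
X0 = [ones(N,1), linspace(-1,1,N)'];          % hypothetical confounds: constant + linear drift
signal = sin(linspace(0,4*pi,N))';            % stand-in for the BOLD signal change
beta_true = [0.7; -0.3];
obs = signal + X0 * beta_true;                % noise-free observations with confounds

% Explicit normal-equations form, as in the script:
intercept = X0 * (X0' * X0)^(-1) * X0' * (obs - signal);

% Equivalent form using the backslash operator, usually more stable numerically:
intercept_bs = X0 * (X0 \ (obs - signal));

disp(max(abs(intercept - X0 * beta_true)))    % zero up to rounding error
```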
Intermediate Results:
if i==1 || ~mod(i,2)
plot_results(fig_handle,state.proxy,simulation,ode_param.proxy.mean,plot_handle,symbols,plot_settings,'not_final');
end
end
Numerical Integration with Estimated ODE Parameters
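Check whether the estimated ODE parameters reproduce the BOLD responses by numerically integrating the candidate ODEs with the estimated parameters. The resulting curves are shown in black.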
simulation2 = simulation_old; simulation2.ode_param = ode_param.proxy.mean';
[simulation2,obs_to_state_relation] = simulate_state_dynamics_dcm(simulation2,symbols,ode,time,...
plot_settings,state.ext_input,'no plot');
state.proxy.num_int = simulation2.state;
Final Results
plot_results(fig_handle,state.proxy,simulation,ode_param.proxy.mean,plot_handle,symbols,...
plot_settings,'final',simulation2.bold_response_true,simulation.odes,candidate_odes);
Time Taken
disp(['time taken: ' num2str(toc) ' seconds'])
References
Gorbach, N.S. Validation and Inference of Structural Connectivity and Neural Dynamics with MRI data. 2018. ETH Zürich Doctoral Thesis. https://www.research-collection.ethz.ch/handle/20.500.11850/261734.
Gorbach, N.S., Bauer, S. and Buhmann, J.M., Scalable Variational Inference for Dynamical Systems. 2017a. Neural Information Processing Systems (NIPS).
Bauer, S., Gorbach, N.S. and Buhmann, J.M., Efficient and Flexible Inference for Stochastic Differential Equations. 2017b. Neural Information Processing Systems (NIPS).
Wenk, P., Gotovos, A., Bauer, S., Gorbach, N.S., Krause, A. and Buhmann, J.M., Fast Gaussian Process Based Gradient Matching for Parameters Identification in Systems of Nonlinear ODEs. 2018. In submission to Conference on Uncertainty in Artificial Intelligence (UAI).
Calderhead, B., Girolami, M. and Lawrence, N.D., 2002. Accelerating Bayesian inference over nonlinear differential equation models. In Advances in Neural Information Processing Systems (NIPS). 22.
The authors in bold font have contributed equally to their respective papers.